#include <seminix/rbtree_augmented.h>
#include <seminix/pgtable.h>
#include <seminix/mmap.h>
#include <seminix/sysctl.h>
#include <seminix/vmacache.h>
#include <seminix/slab.h>
#include <asm/cacheflush.h>
#include <asm/tlb.h>

/*
 * Copy every non-NULL page pointer from @src into @dst (@nr slots).
 *
 * Fix: the parameter names were swapped relative to their actual roles —
 * the first argument is written to and the second is read, which is how
 * every call site uses it.  Renamed so the signature reads correctly as
 * vma_pages_copy(dst, src, nr); behavior is unchanged.
 */
static void vma_pages_copy(struct page **dst, struct page **src, int nr)
{
    for (int i = 0; i < nr; i++) {
        if (src[i])
            dst[i] = src[i];
    }
}

/*
 * Record the page shift for @vma's page array: PAGE_SHIFT for normal
 * mappings, or the shift derived from the hugetlb page size.
 *
 * Returns 0 on success, or the negative value reported by
 * vm_hugetlb_pageshift() for an unsupported hugetlb page size.
 */
static int adjust_pageshift(struct vm_area_struct *vma)
{
    int shift = PAGE_SHIFT;

    if (is_vm_hugetlb_page(vma)) {
        shift = vm_hugetlb_pageshift(vma->pages->page_size);
        if (shift < 0)
            return shift;
    }

    vma->pages->page_shift = shift;
    return 0;
}

/*
 * Allocate and initialise the vma_pages bookkeeping for @vma: one array
 * slot per @page_size-sized page in [vm_start, vm_end), the page size and
 * shift, and a reference count of 1.
 *
 * Returns 0 on success, -ENOMEM on allocation failure, or the negative
 * error from adjust_pageshift() for an invalid hugetlb page size.  On
 * failure everything allocated here is freed again; vma->pages is left
 * dangling and must not be used by the caller.
 *
 * Fix: removed a stray double semicolon after the first error return.
 */
static int vma_pages_init(struct vm_area_struct *vma, unsigned long page_size)
{
    int ret = -ENOMEM;

    vma->pages = kmalloc(sizeof(*vma->pages), GFP_KERNEL);
    if (!vma->pages)
        return -ENOMEM;
    /* Assumes page_size evenly divides the vma length — TODO confirm callers. */
    vma->pages->nr_pages = (vma->vm_end - vma->vm_start) / page_size;
    vma->pages->pages = kcalloc(vma->pages->nr_pages, sizeof(struct page *), GFP_KERNEL);
    if (!vma->pages->pages)
        goto free_pages;
    vma->pages->page_size = page_size;
    ret = adjust_pageshift(vma);
    if (ret)
        goto free_array_pages;
    atomic_set(&vma->pages->pages_ref, 1);

    return 0;
free_array_pages:
    kfree(vma->pages->pages);
free_pages:
    kfree(vma->pages);
    return ret;
}

/* Take an additional reference on a shared vma_pages structure. */
static void __maybe_unused vma_pages_get(struct vma_pages *pages)
{
    atomic_inc(&pages->pages_ref);
}

/*
 * Drop a reference on @pages; on the last reference free the structure.
 * When @free_page is true the backing pages themselves are also released
 * (callers pass false when another vma_pages array has taken over the
 * page pointers).
 */
static void vma_pages_put(struct vma_pages *pages, bool free_page)
{
    if (!atomic_dec_and_test(&pages->pages_ref))
        return;

    if (free_page) {
        int order = get_order(pages->page_size);

        for (int idx = 0; idx < pages->nr_pages; idx++) {
            struct page *page = pages->pages[idx];

            if (page)
                __free_pages(page, order);
        }
    }
    kfree(pages->pages);
    kfree(pages);
}

/* Recompute the page protection for @vm_flags, preserving the bits of
 * @oldprot that pgprot_modify() keeps. */
static pgprot_t vm_pgprot_modify(pgprot_t oldprot, unsigned long vm_flags)
{
    return pgprot_modify(oldprot, vm_get_page_prot(vm_flags));
}

/* Update vma->vm_page_prot to reflect vma->vm_flags. */
void vma_set_page_prot(struct vm_area_struct *vma)
{
    unsigned long vm_flags = vma->vm_flags;
    pgprot_t vm_page_prot;

    vm_page_prot = vm_pgprot_modify(vma->vm_page_prot, vm_flags);
    /* remove_protection_ptes reads vma->vm_page_prot without mmap_lock */
    WRITE_ONCE(vma->vm_page_prot, vm_page_prot);
}

/*
 * Some shared mappings will want the pages marked read-only
 * to track write events. If so, we'll downgrade vm_page_prot
 * to the private version (using protection_map[] without the
 * VM_SHARED bit).
 */
int vma_wants_writenotify(struct vm_area_struct *vma, pgprot_t vm_page_prot)
{
    const unsigned long flags = vma->vm_flags;
    pgprot_t recomputed;

    /* Private or non-writable mappings already have the write bit clear. */
    if ((flags & (VM_WRITE|VM_SHARED)) != (VM_WRITE|VM_SHARED))
        return 0;

    /*
     * If the open routine changed the protections in a way that
     * pgprot_modify() cannot reproduce, leave them alone.
     */
    recomputed = vm_pgprot_modify(vm_page_prot, flags);
    if (pgprot_val(vm_page_prot) != pgprot_val(recomputed))
        return 0;

    return 1;
}

/*
 * Close a vm structure and free it, returning the next.
 */
static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
{
    struct vm_area_struct *next = vma->vm_next;

    /* Drop the vma's page-array reference, freeing backing pages too. */
    vma_pages_put(vma->pages, true);
    vm_area_free(vma);
    return next;
}

/*
 * Compute the augmented rbtree value for @vma: the largest free gap found
 * either immediately before this vma (vm_prev->vm_end .. vm_start) or in
 * one of its rbtree subtrees.  Used by the rbtree callbacks to keep
 * rb_subtree consistent on insert/erase/rotation.
 */
static long vma_compute_subtree(struct vm_area_struct *vma)
{
    unsigned long max, prev_end, subtree;

    /* Gap below this vma: vm_start minus the previous vma's end. */
    max = vma->vm_start;
    if (vma->vm_prev) {
        prev_end =  vma->vm_prev->vm_end;
        if (max > prev_end)
            max -= prev_end;
        else
            max = 0;
    }
    /* Fold in the best gap of each child subtree. */
    if (vma->vm_rb.rb_left) {
        subtree = rb_entry(vma->vm_rb.rb_left,
            struct vm_area_struct, vm_rb)->rb_subtree;
        if (subtree > max)
            max = subtree;
    }
    if (vma->vm_rb.rb_right) {
        subtree = rb_entry(vma->vm_rb.rb_right,
            struct vm_area_struct, vm_rb)->rb_subtree;
        if (subtree > max)
            max = subtree;
    }
    return max;
}

/*
 * Generate the augmented-rbtree callbacks (propagate/copy/rotate) that
 * keep each node's rb_subtree field current via vma_compute_subtree().
 */
RB_DECLARE_CALLBACKS(static, vma_callbacks, struct vm_area_struct, vm_rb,
             unsigned long, rb_subtree, vma_compute_subtree)

/* Re-propagate @vma's rb_subtree value up to the root after a gap change. */
static void vma_update(struct vm_area_struct *vma)
{
    /*
     * As it turns out, RB_DECLARE_CALLBACKS() already created a callback
     * function that does exactly what we want.
     */
    vma_callbacks_propagate(&vma->vm_rb, NULL);
}

/* Insert @vma into the mm rbtree, maintaining the augmented values. */
static inline void vma_rb_insert(struct vm_area_struct *vma,
                 struct rb_root *root)
{
    rb_insert_augmented(&vma->vm_rb, root, &vma_callbacks);
}

/* Erase @vma from the mm rbtree, maintaining the augmented values. */
static void __vma_rb_erase(struct vm_area_struct *vma, struct rb_root *root)
{
    /*
     * Note rb_erase_augmented is a fairly large inline function,
     * so make sure we instantiate it only once with our desired
     * augmented rbtree callbacks.
     */
    rb_erase_augmented(&vma->vm_rb, root, &vma_callbacks);
}

/*
 * Erase @vma while a sibling @ignore vma is in a transient state.
 * NOTE(review): @ignore is currently unused — in the upstream code it is
 * consumed by rbtree validation, which this port does not carry.
 */
static __always_inline void vma_rb_erase_ignore(struct vm_area_struct *vma,
                        struct rb_root *root,
                        struct vm_area_struct *ignore)
{
    __vma_rb_erase(vma, root);
}

/* Plain erase wrapper, kept for symmetry with vma_rb_erase_ignore(). */
static __always_inline void vma_rb_erase(struct vm_area_struct *vma,
                     struct rb_root *root)
{
    __vma_rb_erase(vma, root);
}

/*
 * Walk the rbtree looking for the insertion point of a new vma covering
 * [addr, end).  On success, *pprev is the vma preceding the range (or
 * NULL), and *rb_link / *rb_parent describe the empty tree slot to link
 * the new node into.  Returns -ENOMEM if an existing vma overlaps the
 * requested range.
 */
static int find_vma_links(struct mm_struct *mm, unsigned long addr,
        unsigned long end, struct vm_area_struct **pprev,
        struct rb_node ***rb_link, struct rb_node **rb_parent)
{
    struct rb_node **__rb_link, *__rb_parent, *rb_prev;

    __rb_link = &mm->mm_rb.rb_node;
    rb_prev = __rb_parent = NULL;

    while (*__rb_link) {
        struct vm_area_struct *vma_tmp;

        __rb_parent = *__rb_link;
        vma_tmp = rb_entry(__rb_parent, struct vm_area_struct, vm_rb);

        if (vma_tmp->vm_end > addr) {
            /* Fail if an existing vma overlaps the area */
            if (vma_tmp->vm_start < end)
                return -ENOMEM;
            __rb_link = &__rb_parent->rb_left;
        } else {
            /* Track the closest vma ending at or below addr. */
            rb_prev = __rb_parent;
            __rb_link = &__rb_parent->rb_right;
        }
    }

    *pprev = NULL;
    if (rb_prev)
        *pprev = rb_entry(rb_prev, struct vm_area_struct, vm_rb);
    *rb_link = __rb_link;
    *rb_parent = __rb_parent;
    return 0;
}

/*
 * Link @vma into the rbtree at the slot previously located by
 * find_vma_links(), then fix up the augmented gap bookkeeping.
 */
static void __vma_link_rb(struct mm_struct *mm, struct vm_area_struct *vma,
        struct rb_node **rb_link, struct rb_node *rb_parent)
{
    /* Update tracking information for the gap following the new vma. */
    if (vma->vm_next)
        vma_update(vma->vm_next);
    else
        mm->highest_vm_end = vma->vm_end;

    /*
     * vma->vm_prev wasn't known when we followed the rbtree to find the
     * correct insertion point for that vma. As a result, we could not
     * update the vma vm_rb parents rb_subtree_gap values on the way down.
     * So, we first insert the vma with a zero rb_subtree_gap value
     * (to be consistent with what we did on the way down), and then
     * immediately update the gap to the correct value. Finally we
     * rebalance the rbtree after all augmented values have been set.
     */
    rb_link_node(&vma->vm_rb, rb_parent, rb_link);
    vma->rb_subtree = 0;
    vma_update(vma);
    vma_rb_insert(vma, &mm->mm_rb);
}

/*
 * Link @vma into the mm's doubly-linked vma list after @prev (or at the
 * head when @prev is NULL, in which case the successor is derived from
 * the rbtree parent found during the lookup).
 */
static void __vma_link_list(struct mm_struct *mm, struct vm_area_struct *vma,
        struct vm_area_struct *prev, struct rb_node *rb_parent)
{
    struct vm_area_struct *next;

    vma->vm_prev = prev;
    if (prev) {
        next = prev->vm_next;
        prev->vm_next = vma;
    } else {
        mm->mmap = vma;
        if (rb_parent)
            next = rb_entry(rb_parent,
                    struct vm_area_struct, vm_rb);
        else
            next = NULL;
    }
    vma->vm_next = next;
    if (next)
        next->vm_prev = vma;
}

/* Link @vma into both the list and the rbtree (no map_count change). */
static void
__vma_link(struct mm_struct *mm, struct vm_area_struct *vma,
    struct vm_area_struct *prev, struct rb_node **rb_link,
    struct rb_node *rb_parent)
{
    __vma_link_list(mm, vma, prev, rb_parent);
    __vma_link_rb(mm, vma, rb_link, rb_parent);
}

/* Link @vma into the mm and account for it in map_count. */
static void vma_link(struct mm_struct *mm, struct vm_area_struct *vma,
            struct vm_area_struct *prev, struct rb_node **rb_link,
            struct rb_node *rb_parent)
{
    __vma_link(mm, vma, prev, rb_link, rb_parent);
    mm->map_count++;
}

/*
 * Helper for vma_adjust() in the split_vma insert case: insert a vma into the
 * mm's list and rbtree.  It has already been inserted into the interval tree.
 */
static void __insert_vm_struct(struct mm_struct *mm, struct vm_area_struct *vma)
{
    struct vm_area_struct *prev;
    struct rb_node **rb_link, *rb_parent;

    /* The range was carved out of an existing vma, so it cannot overlap. */
    if (find_vma_links(mm, vma->vm_start, vma->vm_end,
               &prev, &rb_link, &rb_parent))
        BUG();
    __vma_link(mm, vma, prev, rb_link, rb_parent);
    mm->map_count++;
}

/*
 * Unlink @vma from the rbtree and the vma list.  When @has_prev is true,
 * @prev is trusted as the predecessor; otherwise it is re-read from
 * vma->vm_prev.  @ignore is forwarded to vma_rb_erase_ignore() (currently
 * a no-op in this port).  Also invalidates the per-task vma cache.
 */
static __always_inline void __vma_unlink_common(struct mm_struct *mm,
                        struct vm_area_struct *vma,
                        struct vm_area_struct *prev,
                        bool has_prev,
                        struct vm_area_struct *ignore)
{
    struct vm_area_struct *next;

    vma_rb_erase_ignore(vma, &mm->mm_rb, ignore);
    next = vma->vm_next;
    if (has_prev)
        prev->vm_next = next;
    else {
        prev = vma->vm_prev;
        if (prev)
            prev->vm_next = next;
        else
            mm->mmap = next;
    }
    if (next)
        next->vm_prev = prev;

    /* Kill the cache */
    vmacache_invalidate(mm);
}

/* Unlink @vma when its predecessor @prev is already known to the caller. */
static inline void __vma_unlink_prev(struct mm_struct *mm,
                     struct vm_area_struct *vma,
                     struct vm_area_struct *prev)
{
    __vma_unlink_common(mm, vma, prev, true, vma);
}

/*
 * We cannot adjust vm_start, vm_end fields of a vma that
 * is already present in an i_mmap tree without adjusting the tree.
 * The following helper function should be used when such adjustments
 * are necessary.  The "insert" vma (if any) is to be inserted
 * before we drop the necessary locks.
 */
/*
 * Adjust @vma to cover [start, end), possibly inserting @insert (split
 * case) or merging away the following vma(s) (merge cases; @expand is the
 * vma that grows).  The case numbers in the comments refer to the diagram
 * above vma_merge().  Returns 0; kept int-returning for call-site symmetry.
 */
static int __vma_adjust(struct vm_area_struct *vma, unsigned long start,
    unsigned long end, struct vm_area_struct *insert, struct vm_area_struct *expand)
{
    struct mm_struct *mm = vma->vm_mm;
    struct vm_area_struct *next = vma->vm_next;
    bool start_changed = false, end_changed = false;
    long adjust_next = 0;
    int remove_next = 0;

    if (next && !insert) {
        struct vm_area_struct *importer = NULL;

        if (end >= next->vm_end) {
            /*
             * vma expands, overlapping all the next, and
             * perhaps the one after too (mprotect case 6).
             * The only other cases that gets here are
             * case 1, case 7 and case 8.
             */
            if (next == expand) {
                /*
                 * The only case where we don't expand "vma"
                 * and we expand "next" instead is case 8.
                 */
                WARN_ON(end != next->vm_end);
                /*
                 * remove_next == 3 means we're
                 * removing "vma" and that to do so we
                 * swapped "vma" and "next".
                 */
                remove_next = 3;
                swap(vma, next);
            } else {
                WARN_ON(expand != vma);
                /*
                 * case 1, 6, 7, remove_next == 2 is case 6,
                 * remove_next == 1 is case 1 or 7.
                 */
                remove_next = 1 + (end > next->vm_end);
                WARN_ON(remove_next == 2 &&
                        end != next->vm_next->vm_end);
                WARN_ON(remove_next == 1 &&
                        end != next->vm_end);
                /* trim end to next, for case 6 first pass */
                end = next->vm_end;
            }

            importer = vma;
        } else if (end > next->vm_start) {
            /*
             * vma expands, overlapping part of the next:
             * mprotect case 5 shifting the boundary up.
             */
            adjust_next = (end - next->vm_start) >> next->pages->page_shift;
            importer = vma;
            WARN_ON(expand != importer);
        } else if (end < vma->vm_end) {
            /*
             * vma shrinks, and !insert tells it's not
             * split_vma inserting another: so it must be
             * mprotect case 4 shifting the boundary down.
             */
            adjust_next = -((vma->vm_end - end) >> vma->pages->page_shift);
            importer = next;
            WARN_ON(expand != importer);
        }
    }
again:
    /* Apply the new boundaries to vma (and shift next if required). */
    if (start != vma->vm_start) {
        vma->vm_start = start;
        start_changed = true;
    }
    if (end != vma->vm_end) {
        vma->vm_end = end;
        end_changed = true;
    }
    if (adjust_next) {
        next->vm_start += adjust_next << next->pages->page_shift;
    }

    if (remove_next) {
        /*
         * vma_merge has merged next into vma, and needs
         * us to remove next before dropping the locks.
         */
        if (remove_next != 3)
            __vma_unlink_prev(mm, next, vma);
        else
            /*
             * vma is not before next if they've been
             * swapped.
             *
             * pre-swap() next->vm_start was reduced so
             * tell validate_mm_rb to ignore pre-swap()
             * "next" (which is stored in post-swap()
             * "vma").
             */
            __vma_unlink_common(mm, next, NULL, false, vma);
    } else if (insert) {
        /*
         * split_vma has split insert from vma, and needs
         * us to insert it before dropping the locks
         * (it may either follow vma or precede it).
         */
        __insert_vm_struct(mm, insert);
    } else {
        /* Plain boundary move: just fix up the augmented gap values. */
        if (start_changed)
            vma_update(vma);
        if (end_changed) {
            if (!next)
                mm->highest_vm_end = vma->vm_end;
            else if (!adjust_next)
                vma_update(next);
        }
    }

    if (remove_next) {
        mm->map_count--;
        remove_vma(next);
        /*
         * In mprotect's case 6 (see comments on vma_merge),
         * we must remove another next too. It would clutter
         * up the code too much to do both in one go.
         */
        if (remove_next != 3) {
            /*
             * If "next" was removed and vma->vm_end was
             * expanded (up) over it, in turn
             * "next->vm_prev->vm_end" changed and the
             * "vma->vm_next" gap must be updated.
             */
            next = vma->vm_next;
        } else {
            /*
             * For the scope of the comment "next" and
             * "vma" considered pre-swap(): if "vma" was
             * removed, next->vm_start was expanded (down)
             * over it and the "next" gap must be updated.
             * Because of the swap() the post-swap() "vma"
             * actually points to pre-swap() "next"
             * (post-swap() "next" as opposed is now a
             * dangling pointer).
             */
            next = vma;
        }
        if (remove_next == 2) {
            remove_next = 1;
            end = next->vm_end;
            goto again;
        }
        else if (next)
            vma_update(next);
        else {
            /*
             * If remove_next == 2 we obviously can't
             * reach this path.
             *
             * If remove_next == 3 we can't reach this
             * path because pre-swap() next is always not
             * NULL. pre-swap() "next" is not being
             * removed and its next->vm_end is not altered
             * (and furthermore "end" already matches
             * next->vm_end in remove_next == 3).
             *
             * We reach this only in the remove_next == 1
             * case if the "next" vma that was removed was
             * the highest vma of the mm. However in such
             * case next->vm_end == "end" and the extended
             * "vma" has vma->vm_end == next->vm_end so
             * mm->highest_vm_end doesn't need any update
             * in remove_next == 1 case.
             */
            WARN_ON(mm->highest_vm_end != vma->vm_end);
        }
    }

    return 0;
}

/* Convenience wrapper for the split_vma insert case (no expand vma). */
static inline int vma_adjust(struct vm_area_struct *vma, unsigned long start,
    unsigned long end, struct vm_area_struct *insert)
{
    return __vma_adjust(vma, start, end, insert, NULL);
}

/*
 * Two vmas can merge only when their flags match exactly.
 *
 * The original wrote `(vma->vm_flags ^ vm_flags) & ~0`, where the `~0`
 * mask is a no-op (upstream masks out bits such as VM_SOFTDIRTY here).
 * Expressed directly as an equality test; behavior is unchanged.  If a
 * flag bit should be ignored during merging, reintroduce a real mask.
 */
static inline int is_mergeable_vma(struct vm_area_struct *vma, unsigned long vm_flags)
{
    return vma->vm_flags == vm_flags;
}

/*
 * Given a mapping request (addr,end,vm_flags,file,pgoff), figure out
 * whether that can be merged with its predecessor or its successor.
 * Or both (it neatly fills a hole).
 *
 * In most cases - when called for mmap, brk or mremap - [addr,end) is
 * certain not to be mapped by the time vma_merge is called; but when
 * called for mprotect, it is certain to be already mapped (either at
 * an offset within prev, or at the start of next), and the flags of
 * this area are about to be changed to vm_flags - and the no-change
 * case has already been eliminated.
 *
 * The following mprotect cases have to be considered, where AAAA is
 * the area passed down from mprotect_fixup, never extending beyond one
 * vma, PPPPPP is the prev vma specified, and NNNNNN the next vma after:
 *
 *     AAAA             AAAA                AAAA          AAAA
 *    PPPPPPNNNNNN    PPPPPPNNNNNN    PPPPPPNNNNNN    PPPPNNNNXXXX
 *    cannot merge    might become    might become    might become
 *                    PPNNNNNNNNNN    PPPPPPPPPPNN    PPPPPPPPPPPP 6 or
 *    mmap, brk or    case 4 below    case 5 below    PPPPPPPPXXXX 7 or
 *    mremap move:                                    PPPPXXXXXXXX 8
 *        AAAA
 *    PPPP    NNNN    PPPPPPPPPPPP    PPPPPPPPNNNN    PPPPNNNNNNNN
 *    might become    case 1 below    case 2 below    case 3 below
 *
 * It is important for case 8 that the vma NNNN overlapping the
 * region AAAA is never going to extended over XXXX. Instead XXXX must
 * be extended in region AAAA and NNNN must be removed. This way in
 * all cases where vma_merge succeeds, the moment vma_adjust drops the
 * rmap_locks, the properties of the merged vma will be already
 * correct for the whole merged range. Some of those properties like
 * vm_page_prot/vm_flags may be accessed by rmap_walks and they must
 * be correct for the whole merged range immediately after the
 * rmap_locks are released. Otherwise if XXXX would be removed and
 * NNNN would be extended over the XXXX range, remove_migration_ptes
 * or other rmap walkers (if working on addresses beyond the "end"
 * parameter) may establish ptes with the wrong permissions of NNNN
 * instead of the right permissions of XXXX.
 */
/*
 * Try to merge the new range [addr, end) with @prev and/or its successor.
 * On success the surviving vma gets a freshly built vma_pages array into
 * which the old arrays' page pointers are copied; the old arrays are
 * dropped without freeing the pages.  Returns the merged vma or NULL when
 * no merge was possible (or on allocation failure).
 */
static struct vm_area_struct *vma_merge(struct mm_struct *mm,
            struct vm_area_struct *prev, unsigned long addr,
            unsigned long end, unsigned long vm_flags)
{
    struct vm_area_struct new;
    struct vm_area_struct *area, *next;
    int err, index;

    if (vm_flags & VM_SPECIAL)
        return NULL;

    if (prev)
        next = prev->vm_next;
    else
        next = mm->mmap;
    area = next;
    if (area && area->vm_end == end)		/* cases 6, 7, 8 */
        next = next->vm_next;

    /* verify some invariants that must be enforced by the caller */
    WARN_ON(prev && addr <= prev->vm_start);
    WARN_ON(area && end > area->vm_end);
    WARN_ON(addr >= end);

    /*
     * Can it merge with the predecessor?
     */
    if (prev && prev->vm_end == addr &&
            is_mergeable_vma(prev, vm_flags)) {
        /*
         * OK, it can.  Can we now merge in the successor as well?
         */
        if (next && end == next->vm_start &&
                is_mergeable_vma(next, vm_flags)) {
                            /* cases 1, 6 */
            struct vma_pages *pages = next->pages;

            new.vm_start = prev->vm_start;
            new.vm_end = next->vm_end;
            new.vm_flags = prev->vm_flags;
            /* Slot in the merged array where next's pages begin. */
            index = (next->vm_start - prev->vm_start) / prev->pages->page_size;
            err = vma_pages_init(&new, prev->pages->page_size);
            if (err)
                return NULL;
            /* Hold next's array: __vma_adjust() drops next's reference. */
            vma_pages_get(pages);
            err = __vma_adjust(prev, prev->vm_start,
                     next->vm_end, NULL, prev);
            if (err) {
                vma_pages_put(pages, true);
                vma_pages_put(new.pages, true);
                return NULL;
            }
            vma_pages_copy(new.pages->pages, prev->pages->pages, prev->pages->nr_pages);
            /*
             * NOTE(review): the count below is prev->pages->nr_pages - index,
             * yet the slice copied comes from the old "next" array;
             * pages->nr_pages looks like the intended bound — verify.
             */
            vma_pages_copy(new.pages->pages + index, pages->pages, prev->pages->nr_pages - index);
            vma_pages_put(pages, false);
        } else {				/* cases 2, 5, 7 */
            new.vm_start = prev->vm_start;
            new.vm_end = end;
            new.vm_flags = prev->vm_flags;
            err = vma_pages_init(&new, prev->pages->page_size);
            if (err)
                return NULL;
            err = __vma_adjust(prev, prev->vm_start,
                     end, NULL, prev);
            if (err) {
                vma_pages_put(new.pages, true);
                return NULL;
            }
            vma_pages_copy(new.pages->pages, prev->pages->pages, prev->pages->nr_pages);
            vma_pages_put(prev->pages, false);
        }
        prev->pages = new.pages;
        return prev;
    }

    /*
     * Can this new request be merged in front of next?
     */
    if (next && end == next->vm_start &&
            is_mergeable_vma(next, vm_flags)) {
        if (prev && addr < prev->vm_end) {	/* case 4 */
            /* presumably unreachable; kept as a hard guard */
            BUG_ON(1);
            err = __vma_adjust(prev, prev->vm_start,
                     addr, NULL, next);
        } else {					/* cases 3, 8 */
            unsigned long old_start, new_start;
            new.vm_start = addr;
            new.vm_end = next->vm_end;
            new.vm_flags = next->vm_flags;
            old_start = next->vm_start;
            err = vma_pages_init(&new, next->pages->page_size);
            if (err)
                return NULL;
            err = __vma_adjust(area, addr, next->vm_end,
                     NULL, next);
            if (err) {
                vma_pages_put(new.pages, true);
                return NULL;
            }
            new_start = next->vm_start;
            /* next grew downwards; rebase its old pages at this offset. */
            index = (old_start - new_start) / next->pages->page_size;
            vma_pages_copy(new.pages->pages + index, next->pages->pages, next->pages->nr_pages - index);
            vma_pages_put(next->pages, false);
            next->pages = new.pages;
            /*
             * In case 3 area is already equal to next and
             * this is a noop, but in case 8 "area" has
             * been removed and next was expanded over it.
             */
            area = next;
        }
        if (err)
            return NULL;
        return area;
    }

    return NULL;
}

struct vm_area_struct *__find_vma(struct task_struct *tsk, struct mm_struct *mm, unsigned long addr)
{
    struct rb_node *rb_node;
    struct vm_area_struct *vma;

    // vma = __vmacache_find(tsk, mm, addr);
    // if (likely(vma))
    //     return vma;

    rb_node = mm->mm_rb.rb_node;

    while (rb_node) {
        struct vm_area_struct *tmp;

        tmp = rb_entry(rb_node, struct vm_area_struct, vm_rb);

        if (tmp->vm_end > addr) {
            vma = tmp;
            if (tmp->vm_start <= addr)
                break;
            rb_node = rb_node->rb_left;
        } else
            rb_node = rb_node->rb_right;
    }

    // if (vma)
    //     __vmacache_update(tsk, addr, vma);
    return vma;
}

/*
 * Same as find_vma, but also return a pointer to the previous VMA in *pprev.
 */
struct vm_area_struct *
__find_vma_prev(struct task_struct *tsk, struct mm_struct *mm, unsigned long addr,
            struct vm_area_struct **pprev)
{
    struct vm_area_struct *vma;

    vma = __find_vma(tsk, mm, addr);
    if (vma)
        *pprev = vma->vm_prev;
    else {
        /* addr is above every vma: the rightmost node is the predecessor. */
        struct rb_node *rb_node = mm->mm_rb.rb_node;
        *pprev = NULL;
        while (rb_node) {
            *pprev = rb_entry(rb_node, struct vm_area_struct, vm_rb);
            rb_node = rb_node->rb_right;
        }
    }
    return vma;
}

/*
 * Initialise the page arrays of two vmas with the same page size.
 * On failure of the second initialisation, the first is rolled back so
 * the caller sees all-or-nothing semantics.
 */
static int vma_pages_two_init(struct vm_area_struct *a, struct vm_area_struct *b,
    unsigned long page_size)
{
    int rc = vma_pages_init(a, page_size);

    if (rc)
        return rc;

    rc = vma_pages_init(b, page_size);
    if (rc)
        vma_pages_put(a->pages, true);

    return rc;
}

/*
 * Redistribute the page pointers of a just-split vma: the first
 * above->nr_pages slots of the old array @pages belong to @above, the
 * remaining below->nr_pages slots belong to @below (rebased to index 0).
 *
 * Fix: the second loop used to keep the running index `i` (starting at
 * above->nr_pages), compare it against below->nr_pages, and write
 * below->pages[i].  That copied the wrong slice — nothing at all when the
 * lower half is at least as large as the upper half — and indexed
 * below->pages past its end otherwise.
 */
static void __split_vma_copy(struct vma_pages *above, struct vma_pages *below, struct page **pages)
{
    int i;
    struct page *page;

    for (i = 0; i < above->nr_pages; i++) {
        page = pages[i];
        if (page)
            above->pages[i] = page;
    }
    for (i = 0; i < below->nr_pages; i++) {
        page = pages[above->nr_pages + i];
        if (page)
            below->pages[i] = page;
    }
}

/*
 * Split @vma at @addr into two vmas.  @new_below selects whether the new
 * vma covers the lower ([vm_start, addr)) or upper ([addr, vm_end)) part.
 * Fresh page arrays are built for both halves and the old array's page
 * pointers are redistributed via __split_vma_copy().
 *
 * Requires the page array to be unshared (ref == 1); shared-user mappings
 * cannot be split.  Returns 0 on success or a negative errno.
 *
 * Fix: in the new_below branch, below.vm_flags was assigned vma->vm_end
 * (an address) instead of vma->vm_flags.  Also dropped the dead
 * "if (!err)" check on the success path — err is always 0 there.
 */
static int __split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
        unsigned long addr, int new_below)
{
    struct vm_area_struct *new;
    int err;
    struct vm_area_struct above, below;
    struct vma_pages *pages = vma->pages;

    BUG_ON(atomic_read(&vma->pages->pages_ref) != 1);

    if (vma->vm_flags & VM_SHARED_USER)
        return -EPERM;

    new = vm_area_dup(vma);
    if (!new)
        return -ENOMEM;

    if (new_below) {
        /* "new" takes the lower half; "vma" keeps the upper half. */
        new->vm_end = addr;
        above.vm_start = new->vm_start;
        above.vm_end = addr;
        above.vm_flags = new->vm_flags;
        below.vm_start = addr;
        below.vm_end = vma->vm_end;
        below.vm_flags = vma->vm_flags;
        err = vma_pages_two_init(&above, &below, new->pages->page_size);
        if (err)
            goto out;
        err = vma_adjust(vma, addr, vma->vm_end, new);
        if (err)
            goto free_pages;
        BUG_ON(vma->vm_start != addr);
        new->pages = above.pages;
        vma->pages = below.pages;
    } else {
        /* "vma" keeps the lower half; "new" takes the upper half. */
        new->vm_start = addr;
        above.vm_start = vma->vm_start;
        above.vm_end = addr;
        above.vm_flags = vma->vm_flags;
        below.vm_start = addr;
        below.vm_end = new->vm_end;
        below.vm_flags = new->vm_flags;
        err = vma_pages_two_init(&above, &below, vma->pages->page_size);
        if (err)
            goto out;
        err = vma_adjust(vma, vma->vm_start, addr, new);
        if (err)
            goto free_pages;
        BUG_ON(vma->vm_end != addr);
        vma->pages = above.pages;
        new->pages = below.pages;
    }

    /* Success: hand the old page pointers to the two new arrays. */
    __split_vma_copy(above.pages, below.pages, pages->pages);
    vma_pages_put(pages, false);
    return 0;

free_pages:
    vma_pages_put(above.pages, true);
    vma_pages_put(below.pages, true);
out:
    vm_area_free(new);
    return err;
}

/* Split wrapper; the map_count limit check is currently disabled. */
static int split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
          unsigned long addr, int new_below)
{
    // if (mm->map_count >= sysctl_max_map_count)
    //     return -ENOMEM;

    return __split_vma(mm, vma, addr, new_below);
}

/*
 * Get rid of page table information in the indicated region.
 *
 * Called with the mm semaphore held.
 */
static void unmap_region(struct mm_struct *mm,
        struct vm_area_struct *vma, struct vm_area_struct *prev,
        unsigned long start, unsigned long end)
{
    struct vm_area_struct *next = prev ? prev->vm_next : mm->mmap;
    struct mmu_gather tlb;

    /* Free page tables between the neighbours' boundaries, then flush. */
    tlb_gather_mmu(&tlb, mm, start, end);
    free_pgtables(&tlb, vma, prev ? prev->vm_end : FIRST_USER_ADDRESS,
                next ? next->vm_start : USER_PGTABLES_CEILING);
    tlb_finish_mmu(&tlb, start, end);
}

/*
 * Detach the chain of vmas starting at @vma (all with vm_start < @end)
 * from the rbtree and the list, leaving them linked only to each other so
 * remove_vma_list() can free them after the region is unmapped.
 */
static void
detach_vmas_to_be_unmapped(struct mm_struct *mm, struct vm_area_struct *vma,
    struct vm_area_struct *prev, unsigned long end)
{
    struct vm_area_struct **insertion_point;
    struct vm_area_struct *tail_vma = NULL;

    insertion_point = (prev ? &prev->vm_next : &mm->mmap);
    vma->vm_prev = NULL;
    do {
        vma_rb_erase(vma, &mm->mm_rb);
        mm->map_count--;
        tail_vma = vma;
        vma = vma->vm_next;
    } while (vma && vma->vm_start < end);
    *insertion_point = vma;
    if (vma) {
        vma->vm_prev = prev;
        /* The gap before the survivor changed; re-propagate it. */
        vma_update(vma);
    } else
        mm->highest_vm_end = prev ? prev->vm_end : 0;
    tail_vma->vm_next = NULL;

    /* Kill the cache */
    vmacache_invalidate(mm);
}

/*
 * Ok - we have the memory areas we should free on the vma list,
 * so release them, and do the vma updates.
 *
 * Called with the mm semaphore held.
 */
static void remove_vma_list(struct mm_struct *mm, struct vm_area_struct *vma)
{
    struct vm_area_struct *cur = vma;

    /* The caller guarantees a non-empty chain. */
    do {
        cur = remove_vma(cur);
    } while (cur);
}

/*
 * Unmap [start, start + len) from @mm: split any vma straddling either
 * boundary, detach the covered vmas, tear down their page tables and free
 * them.  Returns 0 on success or a negative errno.
 *
 * NOTE(review): when @downgrade is true the write lock is demoted before
 * unmapping, but the function still returns 0 rather than signalling the
 * downgrade to the caller (upstream returns 1 there) — verify callers.
 */
static int do_munmap(struct task_struct *tsk, struct mm_struct *mm,
    unsigned long start, unsigned long len,
    bool downgrade)
{
    unsigned long end;
    struct vm_area_struct *vma, *prev, *last;

    /* Assumes PAGE_SIZE granularity even for hugetlb vmas — TODO confirm. */
    if ((offset_in_page(start)) || start > TASK_SIZE || len > TASK_SIZE - start)
        return -EINVAL;

    len = PAGE_ALIGN(len);
    if (len == 0)
        return -EINVAL;

    /* Find the first overlapping VMA */
    vma = __find_vma(tsk, mm, start);
    if (!vma)
        return 0;
    prev = vma->vm_prev;
    /* we have  start < vma->vm_end  */

    /* if it doesn't overlap, we have nothing.. */
    end = start + len;
    if (vma->vm_start >= end)
        return 0;

    /*
     * If we need to split any vma, do it now to save pain later.
     *
     * Note: mremap's move_vma VM_ACCOUNT handling assumes a partially
     * unmapped vm_area_struct will remain in use: so lower split_vma
     * places tmp vma above, and higher split_vma places tmp vma below.
     */
    if (start > vma->vm_start) {
        int error;

        /*
         * Make sure that map_count on return from munmap() will
         * not exceed its limit; but let map_count go just above
         * its limit temporarily, to help free resources as expected.
         */
        // if (end < vma->vm_end && mm->map_count >= sysctl_max_map_count)
        //     return -ENOMEM;

        error = __split_vma(mm, vma, start, 0);
        if (error)
            return error;
        prev = vma;
    }

    /* Does it split the last one? */
    last = __find_vma(tsk, mm, end);
    if (last && end > last->vm_start) {
        int error = __split_vma(mm, last, end, 1);
        if (error)
            return error;
    }
    vma = prev ? prev->vm_next : mm->mmap;

    /* Detach vmas from rbtree */
    detach_vmas_to_be_unmapped(mm, vma, prev, end);

    if (downgrade)
        downgrade_write(&mm->mmap_sem);

    unmap_region(mm, vma, prev, start, end);

    /* Fix up all other VM information */
    remove_vma_list(mm, vma);

    return 0;
}

/*
 * Public munmap entry point: takes the mmap semaphore for writing and
 * delegates to do_munmap() (never downgrading the lock).
 */
int __munmap(struct task_struct *tsk, struct mm_struct *mm, unsigned long addr, unsigned long len)
{
    int ret;

    if (!mm)
        return -EINVAL;

    down_write(&mm->mmap_sem);
    ret = do_munmap(tsk, mm, addr, len, false);
    up_write(&mm->mmap_sem);
    return ret;
}

static int mmap_region(struct task_struct *tsk, struct mm_struct *mm,
    unsigned long addr, unsigned len, unsigned long vm_flags,
    unsigned long pagesize)
{
    struct vm_area_struct *vma, *prev;
    struct rb_node **rb_link, *rb_parent;
    int error = 0, order;

    while (find_vma_links(mm, addr, addr + len, &prev, &rb_link,
                    &rb_parent)) {
        if (do_munmap(tsk, mm, addr, len, false))
            return -ENOMEM;
    }

    /*
     * Can we just expand an old mapping?
     */
    vma = vma_merge(mm, prev, addr, addr + len, vm_flags);
    if (vma)
        goto out;

    /*
     * Determine the object being mapped and call the appropriate
     * specific mapper. the address has already been validated, but
     * not unmapped, but the maps are removed from the list.
     */
    vma = vm_area_alloc(mm);
    if (!vma) {
        error = -ENOMEM;
        goto out_error;
    }

    vma->vm_start = addr;
    vma->vm_end = addr + len;
    vma->vm_flags = vm_flags;
    vma->vm_page_prot = vm_get_page_prot(vm_flags);
    error = vma_pages_init(vma, pagesize);
    if (error)
        goto free_vma;

    if (vma->vm_flags & VM_ATOMIC) {
        order = get_order(vma->pages->page_size);
        for (int i = 0; i < vma->pages->nr_pages; i++) {
            if (!vma->pages->pages[i]) {
                struct page *page = alloc_pages(GFP_ZERO, order);
                if (!page) {
                    error = -ENOMEM;
                    goto free_vma;
                }
                vma->pages->pages[i] = page;
            }
        }
        error = alloc_pgtables(vma);
        if (error)
            goto free_vma;
    }

    vma_link(mm, vma, prev, rb_link, rb_parent);
out:
    vma_set_page_prot(vma);

    return 0;
free_vma:
    remove_vma(vma);
out_error:
    return error;
}

/*
 * do_mmap() - validate an mmap request and hand it to mmap_region().
 *
 * Converts the caller's @prot/@flags into VM_* bits, checks that @len
 * is non-zero and page-aligned, that @addr is aligned, and that the
 * range does not wrap around the top of the address space.
 *
 * Returns 0 on success or a negative errno.
 * Caller must hold mm->mmap_sem for writing.
 */
static int do_mmap(struct task_struct *tsk, struct mm_struct *mm,
    unsigned long addr, unsigned long len,
    unsigned long prot, unsigned long flags)
{
    unsigned long vm_flags = calc_vm_prot_bits(prot) | calc_vm_flag_bits(flags);
    unsigned long page_size = PAGE_SIZE;

    if (!len)
        return -EINVAL;

    // if (vm_hugetlb_shift(flags)) {
    //     vm_flags |= flags & (SEMINIX_MAP_HUGE_MASK << SEMINIX_MAP_HUGE_SHIFT);
    //     page_size = vm_hugetlb_size(vm_flags);
    //     if (vm_hugetlb_pageshift(page_size) < 0)
    //         return -EINVAL;
    // }

    /* Careful about overflows.. */
    if (!IS_ALIGNED(len, page_size))
        return -EINVAL;

    if (offset_in(addr, page_size - 1))
        return -EINVAL;

    /* Reject ranges that wrap past the top of the address space. */
    if (addr + len < addr)
        return -EINVAL;

    /* Too many mappings? */
    // if (mm->map_count > sysctl_max_map_count)
    //     return -ENOMEM;

    if (vm_flags & VM_SHARED_USER)
        vm_flags |= VM_SHARED;

    return mmap_region(tsk, mm, addr, len, vm_flags, page_size);
}

static int vm_mmap(struct task_struct *tsk, struct mm_struct *mm,
    unsigned long addr, unsigned long len,
    unsigned long prot, unsigned long flags)
{
    int ret;

    down_write(&mm->mmap_sem);
    ret = do_mmap(tsk, mm, addr, len, prot, flags);
    up_write(&mm->mmap_sem);
    return ret;
}

/*
 * __mmap() - public entry point for creating a mapping.
 *
 * Forwards directly to vm_mmap(), which performs the locking,
 * validation and the actual work.
 */
int __mmap(struct task_struct *tsk, struct mm_struct *mm,
    unsigned long addr, unsigned long len,
    unsigned long prot, unsigned long flags)
{
    return vm_mmap(tsk, mm, addr, len, prot, flags);
}

static int
region_fixup(struct vm_area_struct *vma, struct vm_area_struct **pprev,
    unsigned long start, unsigned long end, unsigned long newflags)
{
    struct mm_struct *mm = vma->vm_mm;
    unsigned long oldflags = vma->vm_flags;
    int error = 0;
    int dirty_accountable = 0;

    if (newflags == oldflags) {
        *pprev = vma;
        return 0;
    }

    *pprev = vma_merge(mm, *pprev, start, end, newflags);
    if (*pprev) {
        vma = *pprev;
        WARN_ON((vma->vm_flags ^ newflags) & ~0);
        goto success;
    }

    *pprev = vma;

    if (start != vma->vm_start) {
        error = split_vma(mm, vma, start, 1);
        if (error)
            goto fail;
    }

    if (end != vma->vm_end) {
        error = split_vma(mm, vma, end, 0);
        if (error)
            goto fail;
    }

success:
    /*
     * vm_flags and vm_page_prot are protected by the mmap_sem
     * held in write mode.
     */
    vma->vm_flags = newflags;
    dirty_accountable = vma_wants_writenotify(vma, vma->vm_page_prot);
    vma_set_page_prot(vma);

    change_pgtable(vma, start, end, vma->vm_page_prot, dirty_accountable);
fail:
    return error;
}

/*
 * __mmap_pgprot_modify() - mprotect-style change of protection bits for
 * every mapping inside [addr, addr + len).
 *
 * Walks the VMA list under mmap_sem (write), applying region_fixup() to
 * each piece of the range; VMAs are merged or split as needed so the new
 * flags land on exactly the requested span.
 *
 * Returns 0 on success; -EINVAL for a misaligned @addr or a @prot the
 * architecture rejects; -ENOMEM when the range overflows, starts in a
 * hole, or is not fully covered by contiguous mappings.
 */
int __mmap_pgprot_modify(struct task_struct *tsk, struct mm_struct *mm, unsigned long addr, unsigned long len,
    unsigned long prot)
{
    unsigned long nstart, tmp, reqprot;
    struct vm_area_struct *vma, *prev;
    int error;
    unsigned long start = addr, end;

    /* Start must be page-aligned; a zero length is a successful no-op. */
    if (start & ~PAGE_MASK)
        return -EINVAL;
    if (!len)
        return 0;
    len = PAGE_ALIGN(len);
    end = start + len;
    /* end <= start means the aligned range wrapped around. */
    if (end <= start)
        return -ENOMEM;
    if (!arch_validate_prot(prot, start))
        return -EINVAL;

    /* Remember the caller's prot; it is restored each loop iteration. */
    reqprot = prot;

    down_write(&mm->mmap_sem);
    vma = __find_vma(tsk, mm, start);
    error = -ENOMEM;
    if (!vma)
        goto out;

    /* The range must not begin in a hole before the first VMA. */
    if (vma->vm_start > start)
        goto out;

    /*
     * Seed prev for region_fixup(): if the range starts mid-VMA the
     * VMA itself acts as the predecessor of the piece being changed.
     */
    if (start > vma->vm_start)
        prev = vma;
    else
        prev = vma->vm_prev;

    for (nstart = start ; ; ) {
        unsigned long mask_off_old_flags;
        unsigned long newflags;

        /* Replace only the protection bits; keep all other VM_* flags. */
        mask_off_old_flags = VM_READ | VM_WRITE | VM_EXEC;
        newflags = calc_vm_prot_bits(prot);
        newflags |= (vma->vm_flags & ~mask_off_old_flags);

        /* Clamp this step to the end of the current VMA or the range. */
        tmp = vma->vm_end;
        if (tmp > end)
            tmp = end;
        error = region_fixup(vma, &prev, nstart, tmp, newflags);
        if (error)
            goto out;
        nstart = tmp;

        /* prev may have absorbed more than [nstart, tmp) via a merge. */
        if (nstart < prev->vm_end)
            nstart = prev->vm_end;
        if (nstart >= end)
            goto out;   /* done; error is 0 from region_fixup() */

        /* Any gap between VMAs inside the range is an error. */
        vma = prev->vm_next;
        if (!vma || vma->vm_start != nstart) {
            error = -ENOMEM;
            goto out;
        }
        /* Restore the caller-supplied prot for the next iteration. */
        prot = reqprot;
    }
out:
    up_write(&mm->mmap_sem);
    return error;
}

/*
 * expand_stack() - grow @vma to cover the faulting @address.
 *
 * Stack expansion is not implemented in this kernel; the stub always
 * reports success so callers treat the address as already covered.
 * (Reindented with spaces to match the rest of the file.)
 */
int expand_stack(struct vm_area_struct *vma, unsigned long address)
{
    return 0;
}